{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 309
    },
    "colab_type": "code",
    "id": "AKP2Vok31tsE",
    "outputId": "f268f545-a908-4e75-f957-dad08b72c277"
   },
   "outputs": [],
   "source": [
    "#@title Install arviz\n",
    "# !pip3 install arviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "a0disJ2Uy-Bx"
   },
   "outputs": [],
   "source": [
    "import arviz as az\n",
    "import pystan\n",
    "import os\n",
    "# os.environ['STAN_NUM_THREADS'] = \"4\"\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "YT3IVNpzrBoM"
   },
   "source": [
    "## Select model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-UJBadMvq8aM"
   },
   "outputs": [],
   "source": [
    "import MBS_epidemic_concentration_models as models\n",
    "model = models.model_linearrd_svmitigate()\n",
    "model.plotnetwork()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stanrunmodel = pystan.StanModel(model_code=model.stan)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qn5Cei_mLN0P"
   },
   "source": [
    "# Load data from JHU\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ofuI7UGsLNd7"
   },
   "outputs": [],
   "source": [
    "# JHU CSSE global time series: one CSV each for confirmed cases, deaths and recoveries.\n",
    "jhu_series_dir = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/'\n",
    "\n",
    "url_confirmed = jhu_series_dir + 'time_series_covid19_confirmed_global.csv'\n",
    "url_deaths = jhu_series_dir + 'time_series_covid19_deaths_global.csv'\n",
    "url_recovered = jhu_series_dir + 'time_series_covid19_recovered_global.csv'\n",
    "\n",
    "# Download in the same order as before: confirmed, deaths, recovered.\n",
    "dfc, dfd, dfr = (pd.read_csv(url) for url in (url_confirmed, url_deaths, url_recovered))\n",
    "\n",
    "# print(dfc)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JQDsWrTVteCu"
   },
   "source": [
    "## Make JHU ROI DF"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Enter country "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Austria,Belgium,Denmark,France,Germany,Italy,Norway,Spain,Sweden,Switzerland,United Kingdom\n",
    "roi = \"Netherlands\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "XDoHvSi-NQRQ"
   },
   "outputs": [],
   "source": [
    "# Build the per-day frame DF for the selected region of interest (roi).\n",
    "# Reads dfc / dfd / dfr (JHU confirmed / deaths / recovered tables) and roi.\n",
    "\n",
    "# Select the rows for roi that have a blank Province/State in each table.\n",
    "roi_rows = {}\n",
    "for label, table in (('confirmed', dfc), ('deaths', dfd), ('recovered', dfr)):\n",
    "    rows = table.loc[(table['Country/Region'] == roi) & (pd.isnull(table['Province/State']))]\n",
    "    if rows.empty:\n",
    "        # An empty selection would otherwise fail below with an opaque IndexError.\n",
    "        raise ValueError(f'No country-level row for roi={roi!r} in the JHU {label} table')\n",
    "    roi_rows[label] = rows\n",
    "dfc2, dfd2, dfr2 = roi_rows['confirmed'], roi_rows['deaths'], roi_rows['recovered']\n",
    "\n",
    "# Columns from index 4 onward are used as dates; the last column is dropped.\n",
    "dates_all = dfc.columns[4:].values[:-1]\n",
    "\n",
    "dates = dates_all[:]\n",
    "\n",
    "# First matching row of each table, taken over the same date columns.\n",
    "confirmed = dfc2[dates].iloc[0].to_numpy()\n",
    "recovered = dfr2[dates].iloc[0].to_numpy()\n",
    "deaths = dfd2[dates].iloc[0].to_numpy()\n",
    "\n",
    "# Assemble every row at once rather than growing the frame with DF.loc[i] in a loop:\n",
    "# that was quadratic in the number of dates and left the numeric columns as object dtype.\n",
    "DF = df = pd.DataFrame({\n",
    "    'date': dates,\n",
    "    'cum_cases': confirmed - (recovered + deaths),  # confirmed minus (recovered + deaths)\n",
    "    'cum_recover': recovered,\n",
    "    'cum_deaths': deaths,\n",
    "})\n",
    "\n",
    "# Day-over-day differences; the first row has no predecessor and is therefore NaN.\n",
    "DF[['daily_cases', 'daily_deaths', 'daily_recover']] = df[['cum_cases', 'cum_deaths', 'cum_recover']].diff()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Population per country, keyed by the JHU Country/Region name.\n",
    "pop = {}\n",
    "pop['Italy'] = 60500000\n",
    "pop['United Kingdom'] = 64400000\n",
    "pop['France'] = 66990000\n",
    "pop['Netherlands'] = 17000000\n",
    "\n",
    "# Start of mitigation per country, written exactly like the JHU date column labels (m/d/yy).\n",
    "mitigate = {}\n",
    "mitigate['Italy'] = '3/9/20' #approximate date\n",
    "# The default roi had no entry, so mitigate[roi] raised KeyError.\n",
    "# NOTE(review): approximate date for the Netherlands -- confirm before relying on results.\n",
    "mitigate['Netherlands'] = '3/15/20'\n",
    "\n",
    "if roi not in mitigate:\n",
    "    raise KeyError(f'No mitigation date configured for roi={roi!r}; add an entry to mitigate')\n",
    "\n",
    "# t0 is the day before cumulative cases first exceed 5.\n",
    "# estimated day of first exposure? Need to make this a parameter\n",
    "above_threshold = np.where(DF['cum_cases'].values > 5)[0]\n",
    "if len(above_threshold) == 0:\n",
    "    raise ValueError(f'cum_cases never exceeds 5 for roi={roi!r}; cannot choose t0')\n",
    "# Clamp at 0: a value of -1 would silently index from the end of the series.\n",
    "t0 = max(above_threshold[0] - 1, 0)\n",
    "\n",
    "# Position of the mitigation date within dates (computed once; both panels use it).\n",
    "mitigation_hits = np.where(mitigate[roi] == dates)[0]\n",
    "if len(mitigation_hits) == 0:\n",
    "    raise ValueError(f'Mitigation date {mitigate[roi]!r} is not among the loaded dates')\n",
    "ind = mitigation_hits[0]\n",
    "\n",
    "# Left panel: cumulative columns; right panel: daily columns. Same markers and guide lines.\n",
    "panels = (\n",
    "    ('cum_cases', 'cum_recover', 'cum_deaths'),\n",
    "    ('daily_cases', 'daily_recover', 'daily_deaths'),\n",
    ")\n",
    "for position, (cases_col, recover_col, deaths_col) in enumerate(panels, start=1):\n",
    "    plt.subplot(1, 2, position)\n",
    "    plt.plot(DF[cases_col], 'bo', label='cases')\n",
    "    plt.plot(DF[recover_col], 'go', label='recovered')\n",
    "    plt.plot(DF[deaths_col], 'ks', label='deaths')\n",
    "\n",
    "    plt.axvline(t0, color='k', linestyle='dashed', label='t0')\n",
    "    plt.axvline(ind, color='b', linestyle='dashed', label='mitigate')\n",
    "\n",
    "    plt.legend()\n",
    "\n",
    "\n",
    "print('t0 assumed to be: day '+str(t0))\n",
    "print('t0 date: '+dates[t0])\n",
    "print('mitigation date: '+dates[ind])\n",
    "print(ind)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gLSieVFFtjOb"
   },
   "source": [
    "## Format JHU ROI data for Stan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 299
    },
    "colab_type": "code",
    "id": "iG4K7TG6NTJI",
    "outputId": "12052154-8907-4d79-d85a-7823584d3aca",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Package the observations from t0 onward into model.stan_data.\n",
    "n_predict_days = 365  # how far ts_predict extends past the last observed date\n",
    "\n",
    "model.stan_data['t0'] = t0-1\n",
    "model.stan_data['tm'] = ind\n",
    "model.stan_data['ts'] = np.arange(t0,len(dates))\n",
    "\n",
    "# The daily columns come from .diff(), so their first row is a real NaN.\n",
    "# replace('NaN', 0) only matches the literal string and never touched it;\n",
    "# fillna(0) is what zeroes missing values, keeping the integer cast below well defined.\n",
    "# The string replacement is kept in case literal 'NaN' text is present.\n",
    "DF = DF.fillna(0).replace('NaN', 0)\n",
    "DF = DF.replace(-1, 0)\n",
    "\n",
    "# Columns of y, in order: daily cases, daily recoveries, daily deaths.\n",
    "model.stan_data['y'] = (DF[['daily_cases','daily_recover','daily_deaths']].to_numpy()).astype(int)[t0:,:]\n",
    "model.stan_data['n_obs'] = len(dates) - t0\n",
    "\n",
    "model.stan_data['ts_predict'] = np.arange(t0,len(dates)+n_predict_days)\n",
    "model.stan_data['n_obs_predict'] = len(dates) - t0 + n_predict_days"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Enter population manually"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.stan_data['n_pop'] = pop[roi] \n",
    "model.stan_data['n_scale'] = 10000000 #use this instead of population\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Print data for Stan "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(model.stan_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load England School 1978 Influenza data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #England 1978 influenza\n",
    "# cases = [0,8,26,76,225,298,258,233,189,128,150,85,14,4]\n",
    "# recovered = [0,0,0,0,9,17,105,162,176,166,150,85,47,20]\n",
    "# plt.plot(cases,'bo', label=\"cases\")\n",
    "# plt.plot(recovered,'go',label=\"recovered\")\n",
    "# pop = 763\n",
    "# model.stan_data['t0'] = 0\n",
    "# #truncate time series from t0 on (initial is t0-1)\n",
    "# model.stan_data['n_pop'] = pop \n",
    "# model.stan_data['ts'] = np.arange(1,len(cases)+1)  \n",
    "# Y = np.hstack([np.c_[cases],np.c_[recovered],np.zeros((len(cases),1))]).astype(int)\n",
    "# model.stan_data['y'] = Y\n",
    "# model.stan_data['n_obs'] = len(cases)\n",
    "\n",
    "# plt.plot(cases,'bo', label=\"cases\")\n",
    "# plt.plot(recovered,'go',label=\"recovered\")\n",
    "\n",
    "# plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GDWmF_KQ2Hg3"
   },
   "source": [
    "# Run Stan "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "oImIcxX9Lull"
   },
   "outputs": [],
   "source": [
    "# Feed in some feasible initial values to start from\n",
    "\n",
    "# init_par = [{'theta':[0.25,0.01,0.01,0.05,.02],'S0':0.5}]\n",
    "\n",
    "# Recipe for the initial theta vector of each model, one entry per element of theta:\n",
    "#   ('uniform', [scale, ...])            -> scale * np.random.uniform()\n",
    "#   ('lognormal', [(center, sd), ...])   -> np.random.lognormal(np.log(center), sd)\n",
    "_lognormal_six = [(0.1, 1), (0.1, 1), (0.1, 1), (0.1, 1), (0.25, 1), (0.1, 1)]\n",
    "_theta_init = {\n",
    "    'sir1': ('uniform', [0.5, 0.1]),\n",
    "    'sir2': ('uniform', [0.5, 0.1]),\n",
    "    'sicu': ('uniform', [0.1, 0.1, 0.5, 0.1]),\n",
    "    'sicuq': ('uniform', [0.1, 0.1, 0.5, 0.1, 0.1]),\n",
    "    'sicuf': ('uniform', [0.1, 0.1, 0.5, 0.1, 0.1]),\n",
    "    'sicrqm': ('uniform', [0.1, 0.1, 0.1, 0.1, 0.5, 0.5, 7, 0.1]),\n",
    "    'sicrq': ('lognormal', _lognormal_six),\n",
    "    'linearrd': ('lognormal', _lognormal_six),\n",
    "    # An earlier, disabled variant centred the sixth element on t0 instead of 0.1.\n",
    "    'linearrd_mitigate': ('lognormal', _lognormal_six + [(5, 5), (0.1, 1)]),\n",
    "}\n",
    "\n",
    "# Previously an unmatched model name left init_fun undefined (or stale from an earlier run),\n",
    "# which only showed up as a NameError when sampling started. Fail here with a clear message.\n",
    "if model.name not in _theta_init:\n",
    "    raise ValueError(f'No initial-value recipe for model {model.name!r}; add one to _theta_init')\n",
    "\n",
    "\n",
    "def init_fun():\n",
    "    \"\"\"Draw one random starting point for the sampler.\n",
    "\n",
    "    Returns a dict whose 'theta' entry is a list of draws, one per element of the\n",
    "    recipe registered for model.name, generated in recipe order.\n",
    "    \"\"\"\n",
    "    kind, spec = _theta_init[model.name]\n",
    "    if kind == 'uniform':\n",
    "        return {'theta': [scale*np.random.uniform() for scale in spec]}\n",
    "    return {'theta': [np.random.lognormal(np.log(center), sd) for center, sd in spec]}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit Stan "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 429
    },
    "colab_type": "code",
    "id": "3rtwA8puqVrv",
    "outputId": "061a2610-4f20-4cb2-c8fd-10d97d845fbe"
   },
   "outputs": [],
   "source": [
    "# Sampler settings.\n",
    "n_chains = 4\n",
    "n_warmups = 2000\n",
    "n_iter = 10000\n",
    "n_thin = 50\n",
    "\n",
    "control = {'adapt_delta': 0.95}\n",
    "\n",
    "# Run the compiled model on the formatted data, starting each chain from init_fun().\n",
    "fit = stanrunmodel.sampling(\n",
    "    data=model.stan_data,\n",
    "    init=init_fun,\n",
    "    control=control,\n",
    "    chains=n_chains,\n",
    "    warmup=n_warmups,\n",
    "    iter=n_iter,\n",
    "    thin=n_thin,\n",
    "    seed=13219,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2jXKxrCf3Vj2"
   },
   "outputs": [],
   "source": [
    "print(fit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#https://arviz-devs.github.io/arviz/generated/arviz.plot_density\n",
    "az.plot_density(fit,group='posterior',var_names=[\"theta\",\"R_0\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild DF from the JHU tables for plotting (DF was overwritten while the\n",
    "# Stan inputs were being formatted). Reads dfc / dfd / dfr and roi.\n",
    "\n",
    "# Select the rows for roi that have a blank Province/State in each table.\n",
    "roi_rows = {}\n",
    "for label, table in (('confirmed', dfc), ('deaths', dfd), ('recovered', dfr)):\n",
    "    rows = table.loc[(table['Country/Region'] == roi) & (pd.isnull(table['Province/State']))]\n",
    "    if rows.empty:\n",
    "        # An empty selection would otherwise fail below with an opaque IndexError.\n",
    "        raise ValueError(f'No country-level row for roi={roi!r} in the JHU {label} table')\n",
    "    roi_rows[label] = rows\n",
    "dfc2, dfd2, dfr2 = roi_rows['confirmed'], roi_rows['deaths'], roi_rows['recovered']\n",
    "\n",
    "# Columns from index 4 onward are used as dates; the last column is dropped.\n",
    "dates_all = dfc.columns[4:].values[:-1]\n",
    "\n",
    "dates = dates_all[:]\n",
    "\n",
    "# First matching row of each table, taken over the same date columns.\n",
    "confirmed = dfc2[dates].iloc[0].to_numpy()\n",
    "recovered = dfr2[dates].iloc[0].to_numpy()\n",
    "deaths = dfd2[dates].iloc[0].to_numpy()\n",
    "\n",
    "# Assemble every row at once rather than growing the frame with DF.loc[i] in a loop:\n",
    "# that was quadratic in the number of dates and left the numeric columns as object dtype.\n",
    "DF = df = pd.DataFrame({\n",
    "    'date': dates,\n",
    "    'cum_cases': confirmed - (recovered + deaths),  # confirmed minus (recovered + deaths)\n",
    "    'cum_recover': recovered,\n",
    "    'cum_deaths': deaths,\n",
    "})\n",
    "\n",
    "# Day-over-day differences; the first row has no predecessor and is therefore NaN.\n",
    "DF[['daily_cases', 'daily_deaths', 'daily_recover']] = df[['cum_cases', 'cum_deaths', 'cum_recover']].diff()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the observed daily counts (from t0 onward) with the fitted state curves.\n",
    "ms = 2\n",
    "\n",
    "x = range(len(dates[t0:]))\n",
    "\n",
    "print(len(x))\n",
    "print(np.shape(DF))\n",
    "\n",
    "# Observed data. Recovered and deaths were both labelled 'unknown'; they now carry\n",
    "# the same names used for these series in the overview plot.\n",
    "plt.plot(x, DF['daily_cases'][t0:], 'bo', label='cases', ms=ms)\n",
    "plt.plot(x, DF['daily_recover'][t0:], 'o', color='orange', label='recovered', ms=ms)\n",
    "plt.plot(x, DF['daily_deaths'][t0:], 'x', color='k', label='deaths', ms=ms)\n",
    "\n",
    "# Legend labels for the leading columns of the fitted array u, per model name.\n",
    "state_labels = {'linearrd': ['I', 'C']}\n",
    "c_ = ['b','orange','k','r','g','m']\n",
    "lw = 4\n",
    "a = 0.5\n",
    "\n",
    "labels = state_labels.get(model.name)\n",
    "if labels is None:\n",
    "    # labels / n used to exist only for 'linearrd', so every other model died on a NameError.\n",
    "    print('No state labels configured for model ' + repr(model.name) + '; plotting data only')\n",
    "else:\n",
    "    # Extract once instead of once per curve; [-1] is the last extracted draw, indexed [time, state].\n",
    "    u_last = fit.extract()['u'][-1]\n",
    "    for i, state_label in enumerate(labels):\n",
    "        plt.plot(model.stan_data['n_scale']*u_last[:, i], label=state_label, lw=lw, alpha=a, color=c_[i])\n",
    "\n",
    "plt.legend()\n",
    "plt.ylim((0,50000))\n",
    "plt.title(roi)\n",
    "plt.figure()\n",
    "az.plot_density(fit, group='posterior', var_names=['theta','R_0'])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "theta = az.from_pystan(posterior=fit, coords={'a': ['sigmaC','sigmaR', 'sigmaD', 'q','beta','inittheta']}, dims={'theta': ['a']})\n",
    "\n",
    "\n",
    "az.plot_pair(theta,var_names=[\"theta\"], coords={'a':['sigmaC','beta']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lower / upper quantile curves for the first two predicted states, scaled by n_scale.\n",
    "#\n",
    "# These used to be scraped from the text of fit.stansummary() with split(' ') and fixed\n",
    "# token positions (different positions for each state), which depends on column widths\n",
    "# and is rounded to two digits. They are now computed directly from the draws.\n",
    "n_days = 46  # same number of leading time points as before\n",
    "# NOTE(review): the lower level is 25% while the upper is 97.5%; 2.5% may have been intended -- confirm.\n",
    "lower_pct, upper_pct = 25, 97.5\n",
    "\n",
    "# u_predict is assumed to be indexed [draw, time, state], like u in the plotting cell.\n",
    "u_predict = fit.extract()['u_predict']\n",
    "scaled = model.stan_data['n_scale'] * u_predict[:, :n_days, :2]\n",
    "lower, upper = np.percentile(scaled, [lower_pct, upper_pct], axis=0)\n",
    "\n",
    "# State 1 -> C, state 2 -> R (Stan indices 1 and 2 are columns 0 and 1 here).\n",
    "lbC, ubC = lower[:, 0].tolist(), upper[:, 0].tolist()\n",
    "lbR, ubR = lower[:, 1].tolist(), upper[:, 1].tolist()\n",
    "\n",
    "plt.plot(lbC)\n",
    "plt.plot(ubC)\n",
    "plt.plot(lbR)\n",
    "plt.plot(ubR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forest plot of u at a single index of its first dimension.\n",
    "# The previous version declared one coordinate label ('11') for that dimension and then\n",
    "# asked the plot for label '0', which is not among the declared labels. ArviZ is now left\n",
    "# to name and number the dimensions itself (u_dim_0, u_dim_1, integer labels from 0).\n",
    "# NOTE(review): assumes the first index was the intended selection -- confirm.\n",
    "u = az.from_pystan(posterior=fit)\n",
    "\n",
    "\n",
    "axes = az.plot_forest(\n",
    "    u,\n",
    "    kind='forestplot',\n",
    "    var_names=['u'],\n",
    "    coords={'u_dim_0': [0]},\n",
    "    combined=True,\n",
    "    ridgeplot_overlap=1.5,\n",
    "    colors='blue',\n",
    "    figsize=(9, 4),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.shape(fit.extract()['u']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior summary of u as a table. The cell previously held the R expression\n",
    "# as.data.frame(summary(fit)[['u']]), which is a syntax error in a Python kernel.\n",
    "u_summary = fit.summary(pars=['u'])\n",
    "pd.DataFrame(\n",
    "    u_summary['summary'],\n",
    "    index=u_summary['summary_rownames'],\n",
    "    columns=u_summary['summary_colnames'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "MBS_SIRmodeling_version2.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
